% Finds CitcomS .cap, .opt, and .tracer files created by CitcomS's
% autocombine.py routine and plots cross sections of selected datasets
%
% Created by M. J. Jones
% Last modified May 2024 by M. J. Jones

clear   % start from an empty workspace
clf     % and a blank current figure

%**************************************************************************

% ///Changeable variables///
% //Script parameters//
% ****IMPORTANT: set `base_dir` to the folder that holds the model data directories****
base_dir = "~/Downloads/";

% Which model to plot -- one of "deep", "shallow", or "none"
model_ext = "deep";

% Which data to plot, one entry per subplot
%   plot_types: 0 = temperature, 1 = viscosity, 2 = velocity,
%               3 = late-stage cumulates only, 4 = Mg-suite only,
%               5 = composition overlay (late-stage cumulates & Mg-suite)
%   velo_types: velocity component for each subplot whose plot_types entry is 2
%               (0 = radial, 1 = colatitude, 2 = longitude,
%                3 = cross-sectional magnitude with quiver arrows);
%               NaN for subplots that do not show velocity
plot_types = [0 1 2 5 2 2];
velo_types = [NaN NaN 3 NaN 0 1];

% Time step selection
%   pick_steps:     1 = plot only the steps listed in `timestep_array`; 0 = plot every step
%   autoselect:     1 = replace `timestep_array` with the steps shown in Figure 2 of the
%                   manuscript plus the final step (only when pick_steps is 1); 0 = off
%   timestep_array: integer scalar or vector of time steps to plot
pick_steps = 1;
autoselect = 0;
timestep_array = 0;


% //Figure parameters//
% General settings
save_figs = 0;                                      % 1 = write each frame as a PNG into `save_dir`; 0 = display only
save_dir = "~/Documents/Jones2024_CitcomS-Plots/";  % absolute or relative output directory

darkmode      = 0;  % 1 = black background with white annotation
format_style  = 0;  % 0 = labeled plots (animation frames); 1 = blank plots (manuscript figures)
show_time     = 1;  % 1 = draw the model-time label
show_progress = 1;  % 1 = draw the time progress bar


% //Colormap settings//
n_colors = 512;     % number of entries per colormap

% Colormap limits
clim_T    = [1000 1900];  % temperature [K]
clim_visc = [18 22];      % viscosity [log10(Pa s)]
clim_LSC  = [0 1];        % late-stage cumulates [volume fraction]
clim_mg   = clim_LSC;     % Mg-suite [volume fraction]
clim_vr   = [-5 5];       % radial velocity [cm/yr]
clim_vlat = clim_vr;      % colatitude velocity [cm/yr]
clim_vlon = clim_vr;      % longitude velocity [cm/yr]
clim_v    = [0 5];        % cross-sectional velocity magnitude [cm/yr]




% ///Variables that may break the script if changed///
% General settings
slice_thres = 0.01 * pi; % width [rad] of the longitude band kept for the cap-data cross-section; also scales the tracer slab thickness
ncomp = 1;               % number of composition columns read from each .opt file -- do not change
plotcomp = 1;            % do not change
% Default depth range [km] of the initial Mg-suite parent material
% (the maximum is reassigned according to `model_ext` in the automated settings)
Mgsuite_max_depth = 700;                   % deepest initial position [km]
Mgsuite_min_depth = Mgsuite_max_depth/2;   % shallowest initial position [km]

% Data processing
ncaps = 12;              % caps in the CitcomS autocombined output (always 12)
nproc = 120;             % processors that wrote .tracer files (`ncaps` is used for all other data)
grid_dim = 500;          % points per side of the regular interpolation grid
grid_dim_trc = grid_dim; % bins per side of the tracer-based composition grid

% Plot settings
cbar_lw = 2;               % colorbar line width (MATLAB default is 0.5)
border_lw = 1.75;          % line width of the circular plot borders
border_color = [.5 .5 .5]; % color of the circular plot borders
vector_dim = 25;           % points per side of the coarse grid on which velocity arrows are drawn

ncol = 2;          % subplot columns (the subplot positioning code assumes 2)
sp_size = 500;     % side length of one subplot [px]
fontsize = 24;     % base font size
fig_position = 0;  % horizontal position of the figure window on screen [px]
max_fig_H = 1300;  % maximum figure height [px]; taller layouts are shrunk to preserve subplot aspect


%**************************************************************************
% //Automated settings//
% Filename and directory setup
suite_name = "moon_suite_23_06_26";
grid_resolution = "";
input_filename = suite_name + "_" + model_ext;

% Depth range [km] of the initial Mg-suite parent material for each model:
% material starts between `Mgsuite_min_depth` (half of the maximum depth)
% and `Mgsuite_max_depth`
if model_ext == "shallow"
    Mgsuite_max_depth = 700;
elseif model_ext == "deep"
    Mgsuite_max_depth = 1000;
elseif model_ext == "none"
    Mgsuite_max_depth = (1700 - 330); %outer radius minus core radius, i.e. the full mantle depth [km]
else
    %fail early rather than silently keeping the default depth range and
    %looking for a data directory that does not exist
    error("Error: Variable 'model_ext' is not set correctly.")
end
Mgsuite_min_depth = Mgsuite_max_depth/2; %km -- minimum depth for initial position of Mg-suite parent material
directory = base_dir + input_filename + "/";

% Timestep setup (if automated): the steps shown in Figure 2 of the
% manuscript plus the final step of each simulation
if pick_steps && autoselect
    if model_ext == "shallow"
        timestep_array = [0 5 360 775 1380];
    elseif model_ext == "deep"
        timestep_array = [0 5 265 755 1320];
    elseif model_ext == "none"
        timestep_array = [0 5 560 1180];
    end
end

% Figure title: "<input_filename>_<grid_resolution><tag1><tag2>_[dark_]", where
% the tags name the data types of the first two subplots. Every entry of
% `plot_types` (and of `velo_types` for velocity subplots) is validated here,
% so a misconfigured subplot fails before any data files are read.
n_title_types = 2; %number of leading subplots whose data types appear in the title
figure_title = input_filename + "_" + grid_resolution;
for ii = 1:numel(plot_types)
    switch plot_types(ii)
        case 0
            type_tag = "temp";
        case 1
            type_tag = "visc";
        case 2
            switch velo_types(ii)
                case 0
                    type_tag = "vr";
                case 1
                    type_tag = "vlat";
                case 2
                    type_tag = "vlon";
                case 3
                    type_tag = "v";
                otherwise %includes NaN
                    error("Error: Variable 'velo_types' is not set correctly.")
            end
        case 3
            type_tag = "comp";
        case {4, 5}
            type_tag = "Mg";
        otherwise
            error("Error: Variable 'plot_types' is not set correctly.")
    end
    if ii <= n_title_types
        figure_title = figure_title + type_tag;
    end
end
figure_title = figure_title + "_";
if darkmode
    figure_title = figure_title + "dark_";
end

% Dimensionalization parameters
% reference material properties
eta0  = 1e20;  % reference viscosity [Pa s]
kappa = 1e-6;  % thermal diffusivity [m^2/s]

% thermal boundary conditions; dimensional T = T_nd * Tref + Tsurf
Tsurf = 1600;          % surface temperature [K]
Tcmb  = 1850;          % core-mantle boundary temperature [K]
Tref  = Tcmb - Tsurf;  % temperature scale, ΔT [K]

% geometry
Rbody = 1700;              % outer radius of the model [km]
Rcore = 330;               % inner radius of the model [km]
Rcore_nd = Rcore / Rbody;  % inner radius, nondimensional

% velocity scale: (nondimensional velocity) / scalev = velocity in cm/yr
scalev = Rbody * 1000 / (kappa * 100 * 365.25 * 24 * 3600);

% Read in time info. The .MJtime file is read as space-delimited numbers;
% column 1 is the time step and column 2 the model time (read as Myr by the
% plotting loop). Column 3 is discarded.
time_array = dlmread(directory+input_filename+".MJtime", ' ', 0, 0);
time_array(:, 3) = [];

% Final step/time of the whole simulation, taken BEFORE any step selection
% so the progress bar shows progress through the full run rather than
% through the selected steps only
max_step = time_array(end, 1);
max_time = time_array(end, 2);

if pick_steps
    time_array = time_array(ismember(time_array(:, 1), timestep_array), :);

    %requested steps with no entry in the time file cannot be plotted
    missing_steps = setdiff(timestep_array(:), time_array(:, 1));
    if ~isempty(missing_steps)
        warning("Requested time steps not found in the time file and skipped: %s", ...
            num2str(missing_steps'))
    end
    if isempty(time_array)
        error("Error: None of the time steps in 'timestep_array' exist in the time file.")
    end
end
%keep `timestep_array` in sync with the rows actually available
timestep_array = time_array(:, 1);

% Core configuration: outline of the core (nondimensional), sampled every 0.01 rad
theta = 0:0.01:2*pi;
core_x = Rcore_nd .* cos(theta);
core_y = Rcore_nd .* sin(theta);

% Mg-suite configuration: nondimensional RADII bounding the initial position
% of the Mg-suite parent material (the greatest depth gives the smallest radius)
Mgsuite_min = 1 - Mgsuite_max_depth/Rbody;
Mgsuite_max = 1 - Mgsuite_min_depth/Rbody;

% Colormap definitions (every map has `n_colors` entries)
%late-stage cumulate (KREEP) colormap; the first entry is gray, so values at
%the bottom of the color limits are drawn in gray
comp_cmap = cool(n_colors);
comp_cmap(1, :) = [.7 .7 .7];

%Mg-suite colormap (reversed `summer`), also with a gray first entry
mg_cmap = flipud(summer(n_colors));
mg_cmap(1, :) = [.7 .7 .7];

%temperature colormap: a black-to-blue ramp over the lowest 40% of the
%entries, followed by `jet` for the remainder
nrows = floor(n_colors * 0.4);
blk2blu = zeros(nrows, 3);
blk2blu(:, 3) = linspace(0, 0.5078, nrows);
temp_cmap = [blk2blu; jet(n_colors-nrows)];

%velocity colormap built from cyan (min), near-white (middle), and magenta
%(max) anchor colors -- presumably interpolated between the given entry
%indices by `generate_custom_cmap` (defined outside this script; confirm)
cmax = [1 0 1];
cmid = [.95 .95 .95];
cmin = [0 1 1];
velo_cmap = generate_custom_cmap(n_colors, 3, [cmin; cmid; cmax]', [0 round(n_colors/2) n_colors]);

%viscosity colormap (reversed `turbo`); sized by `n_colors` like the other
%maps instead of a hard-coded 512
visc_cmap = flipud(turbo(n_colors));

% Darkmode settings
%   textcolor / bgcolor / corecolor: annotation, background, and core-fill colors
%   invertsetting: value applied to the figure's 'InvertHardCopy' property
if ~darkmode
    textcolor = [0 0 0];
    bgcolor = [1 1 1];
    corecolor = [1 1 1];
    invertsetting = 'on';
else
    textcolor = [1 1 1];
    bgcolor = [0 0 0];
    corecolor = [0 0 0];
    invertsetting = 'off';
end

% Misc. plot settings
%cap the core-border linewidth at 1.5 so that it doesn't cover up data
border_lw_core = min(border_lw, 1.5);

% Make circle border for plots: a unit circle (outer boundary) and a copy
% scaled to the nondimensional core radius (inner boundary)
circ_ang = linspace(0, 2*pi, 100);
circx = sin(circ_ang);
circy = cos(circ_ang);
circx_core = Rcore_nd .* circx;
circy_core = Rcore_nd .* circy;

% Set plot dimensions: an nrow-by-ncol layout of square subplots `sp_size` px
% on a side, shrunk uniformly if the figure would be taller than `max_fig_H`
nplot = length(plot_types);
nrow = ceil(nplot / ncol);
sp_W0 = 1 / ncol; %subplot width (normalized figure units)
sp_H0 = 1 / nrow; %subplot height (normalized figure units)
fig_W = ncol * sp_size; %figure width [px]
fig_H = nrow * sp_size; %figure height [px]
if fig_H > max_fig_H
    f_rescale = max_fig_H / fig_H; %shrink factor that keeps the subplot aspect
    [fig_W, fig_H] = deal(fig_W * f_rescale, fig_H * f_rescale);
    warning("Figure size has been reduced to prevent change in subplot aspect. Ignoring `sp_size`.")
end

% Set datafile prefixes: "<directory><input_filename>.capNN." / ".optNN." for
% caps NN = 00 .. ncaps-1, and "<directory><input_filename>.tracer.P." for
% processors P = 0 .. nproc-1 (the time step number is appended later)
capIDs = 0:ncaps-1;
cap_exts = string(num2str(capIDs', '%02d'));
file_stem = directory + input_filename;
capfiles_template = file_stem + ".cap" + cap_exts + ".";
optfiles_template = file_stem + ".opt" + cap_exts + ".";

trc_exts = string((0:nproc-1)');
trcfiles_template = file_stem + ".tracer." + trc_exts + ".";



% //Main script//
% Print a progress message to the console
fprintf("Timesteps completed (%s):\n", input_filename)

% Loop through all autocombined files, plot the specified number of processors/wedges at each time step.
% One iteration per row of `time_array` (column 1: time step, column 2: time
% in Myr), so a step requested in `timestep_array` but absent from the time
% file can never push the loop index past the end of `time_array`.
for tID = 1:size(time_array, 1)
    % Update time variables
    tstep = time_array(tID, 1);
    tstep_str = string(tstep);
    tMyr = time_array(tID, 2);
    if max_time > 0
        time_progress = tMyr / max_time;
    else
        time_progress = 0;
    end

    % Set current filenames (each template ends in "." and is completed by the step number)
    capfiles = strcat(capfiles_template, repmat(tstep_str, ncaps, 1));
    optfiles = strcat(optfiles_template, repmat(tstep_str, ncaps, 1));
    trcfiles = strcat(trcfiles_template, repmat(tstep_str, nproc, 1));

    clf %clear figure if there's more data to plot
    
    %%% //Read in .cap and .opt files for this timestep, starting by reserving enough array space//
    % Each file is read as space-delimited numbers after one header line.
    % .cap columns: 1 colatitude, 2 longitude, 3 radius, 4 colatitude velocity,
    % 5 longitude velocity, 6 radial velocity, 7 temperature, 8 viscosity;
    % .opt columns 1:ncomp: composition
    capfile = dlmread(capfiles(1), ' ', 1, 0);
    optfile = dlmread(optfiles(1), ' ', 1, 0);
    
    nd_per_cap = length(capfile(:, 1)); %nodes per cap (all caps are assumed to have the same count)
    allocation = zeros((nd_per_cap*(ncaps-1)), 1);
    colat = [capfile(:, 1); allocation];
    lon = [capfile(:, 2); allocation];
    rel_radius = [capfile(:, 3); allocation];
    vel_colat = [capfile(:, 4); allocation];
    vel_lon = [capfile(:, 5); allocation];
    vel_r = [capfile(:, 6); allocation];
    temp = [capfile(:, 7); allocation];
    visc = [capfile(:, 8); allocation];
    
    comp_allocation = zeros((nd_per_cap*(ncaps-1)), ncomp);
    comp = [optfile(:, 1:ncomp); comp_allocation];
    
    % Collect all data points for full sphere into vectors. Cap number
    % `capID` (1-based) fills rows (capID-1)*nd_per_cap+1 : capID*nd_per_cap,
    % so the preallocated arrays are filled exactly, with no unfilled (zero)
    % slot and no growth past the preallocated size
    for capID = 2:ncaps
        capfile = dlmread(capfiles(capID), ' ', 1, 0);
        optfile = dlmread(optfiles(capID), ' ', 1, 0);
        
        current_IDs = (nd_per_cap*(capID-1)+1):(nd_per_cap*capID);
        colat(current_IDs) = capfile(:, 1);
        lon(current_IDs) = capfile(:, 2);
        rel_radius(current_IDs) = capfile(:, 3);
        vel_colat(current_IDs) = capfile(:, 4);
        vel_lon(current_IDs) = capfile(:, 5);
        vel_r(current_IDs) = capfile(:, 6);
        temp(current_IDs) = capfile(:, 7);
        visc(current_IDs) = capfile(:, 8);
        
        comp(current_IDs, :) = optfile(:, 1:ncomp);
    end
    comp = 1.0 - comp;
    clearvars capfile optfile allocation comp_allocation current_IDs
    
    %%% **cap and opt arrays are now fully loaded**

    %%% //Read in .tracer files for this timestep//
    % The number of tracers per processor may vary, so each file is stored in
    % a cell and everything is concatenated once at the end (instead of
    % re-growing seven arrays for every file). Columns used: 1 colatitude,
    % 2 longitude, 3 radius, 4 flavor, 5-7 initial colatitude/longitude/radius
    trc_blocks = cell(numel(trcfiles), 1);
    for ff = 1:numel(trcfiles)
        trcfile = dlmread(trcfiles(ff), ' ', 1, 0);
        trc_blocks{ff} = trcfile(:, 1:7);
    end
    trc_all = vertcat(trc_blocks{:});
    clearvars trc_blocks trcfile

    trc_colat = trc_all(:, 1);
    trc_lon = trc_all(:, 2);
    trc_r = trc_all(:, 3);
    trc_flav = int8(trc_all(:, 4));
    trc_OGcolat = trc_all(:, 5);
    trc_OGlon = trc_all(:, 6);
    trc_OGr = trc_all(:, 7);
    clearvars trc_all

    % Convert spherical coordinates produced by CitcomS to Cartesian
    % (elevation = colatitude - pi/2, so z = -radius*cos(colatitude))
    [coord_x, coord_y, coord_z] = sph2cart(lon, colat-(pi/2), rel_radius);
    [trc_x, trc_y, trc_z] = sph2cart(trc_lon, trc_colat-(pi/2), trc_r);

    
    %%% //Process data//
    % Create a cross-section: cap nodes whose longitude lies within
    % slice_thres/2 of a multiple of pi, and tracers inside a slab around y = 0
    IDs_insection = (mod(lon, pi) <= slice_thres/2) | (mod(lon, pi) >= (pi - slice_thres/2));
    IDs_insection_trc = abs(trc_y) <= (slice_thres * (grid_dim_trc / 100)); %**NOTE**: this is poorly designed; `grid_dim_trc` == 500 when `slice_thres` == .01*pi means slice depth is 15% of lunar radius
    
    % Cull data points to only those within the desired plane
    colat = colat(IDs_insection);
    lon = lon(IDs_insection);
    rel_radius = rel_radius(IDs_insection);
    vel_colat = vel_colat(IDs_insection);
    vel_lon = vel_lon(IDs_insection);
    vel_r = vel_r(IDs_insection);
    temp = temp(IDs_insection);
    visc = visc(IDs_insection);
    coord_x = coord_x(IDs_insection);
    coord_y = coord_y(IDs_insection);
    coord_z = coord_z(IDs_insection);
    
    comp = comp(IDs_insection, :);

    clearvars IDs_insection
    
    trc_colat = trc_colat(IDs_insection_trc);
    trc_lon = trc_lon(IDs_insection_trc);
    trc_r = trc_r(IDs_insection_trc);
    trc_flav = trc_flav(IDs_insection_trc);
    trc_OGcolat = trc_OGcolat(IDs_insection_trc);
    trc_OGlon = trc_OGlon(IDs_insection_trc);
    trc_OGr = trc_OGr(IDs_insection_trc);
    trc_x = trc_x(IDs_insection_trc);
    trc_y = trc_y(IDs_insection_trc);
    trc_z = trc_z(IDs_insection_trc);

    clearvars IDs_insection_trc
    
    % Get rid of repeated coordinates (from overlapping of processors). The
    % Cartesian coordinates are culled with the same indices so that they stay
    % element-aligned with the data (the interpolator and the tracer masks
    % below index both together)
    coords = [colat, lon, rel_radius];
    [~, unq_coordID] = unique(coords, 'rows', 'stable');
    
    colat = colat(unq_coordID);
    lon = lon(unq_coordID);
    rel_radius = rel_radius(unq_coordID);
    vel_colat = vel_colat(unq_coordID);
    vel_lon = vel_lon(unq_coordID);
    vel_r = vel_r(unq_coordID);
    temp = temp(unq_coordID);
    visc = visc(unq_coordID);
    coord_x = coord_x(unq_coordID);
    coord_y = coord_y(unq_coordID);
    coord_z = coord_z(unq_coordID);
    
    comp = comp(unq_coordID, :);

    clearvars unq_coordID

    coords_trc = [trc_colat, trc_lon, trc_r];
    [~, trc_unqID] = unique(coords_trc, 'rows', 'stable');

    trc_colat = trc_colat(trc_unqID);
    trc_lon = trc_lon(trc_unqID);
    trc_r = trc_r(trc_unqID);
    trc_flav = trc_flav(trc_unqID);
    trc_OGcolat = trc_OGcolat(trc_unqID);
    trc_OGlon = trc_OGlon(trc_unqID);
    trc_OGr = trc_OGr(trc_unqID);
    trc_x = trc_x(trc_unqID);
    trc_y = trc_y(trc_unqID);
    trc_z = trc_z(trc_unqID);

    clearvars trc_unqID

    % Generate a baseline data interpolator for the specified grid. It is built
    % once (first time step) from the cross-section node positions and later
    % reused with new `Values`, which assumes the node set does not change
    % between time steps.
    if ~exist('reg_interp', 'var')
        reg_interp = scatteredInterpolant(coord_x, coord_z, ones(size(coord_x)));
    end
    
    % Create a grid of points and get rid of points outside of cross-section
    xlin = linspace(-1, 1, grid_dim);
    zlin = linspace(-1, 1, grid_dim);
    [X, Z] = meshgrid(xlin, zlin);
    in_domain = (sqrt(X.^2 + Z.^2) <= 1) & (sqrt(X.^2 + Z.^2) >= Rcore_nd);
    X(~in_domain) = NaN;
    Z(~in_domain) = NaN;
    
    % //Plotting//
    Px = 0; %horizontal position of subplot (sets pos. of bottom left corner)
    Py0 = 1 - sp_H0; %vertical position of subplot (sets pos. of bottom left corner)
    % Loop through subplots/data types
    for ii = 1:nplot
        % Set data interpolator
        if ismember(plot_types(ii), [0 1 2 3])
            interpolator = reg_interp;
        end

        % Misc. settings
        sp_W = sp_W0;
        sp_H = sp_H0;
        ii_even = ~mod(ii, 2); %bool for whether current plot is on left or right side (layout assumes ncol == 2)

        % Set subplot positioning
        sbplt = subplot(nrow, ncol, ii);
        
        %horizontal position
        if ii_even; Px = 0.5; else; Px = 0; end

        %vertical position (changes for each row)
        if (ii > 1) && ~ii_even; Py0 = Py0 - sp_H0; end
        Py = Py0;

        %modify position for final plot if nplots is odd
        if (ii == nplot) && ~ii_even; Px = 0.25; end

        %reduce subplot size if using colorbars
        if format_style == 0
            sp_pad = (0.075 / nrow) * 1.2; %white space padding on either side of subplots
            sp_W = sp_W0 - (2 * sp_pad);
            sp_H = sp_H0 - (2 * sp_pad);
            Px = Px + sp_pad;
            Py = Py + sp_pad;
        end

        %set position
        set(sbplt, 'Position', [Px Py sp_W sp_H])


        % Set data and corresponding plot parameters
        if plot_types(ii) == 0 %temperature
            data = (temp .* Tref) + Tsurf; %K
            cmap = temp_cmap;
            clims = clim_T;
            cbar_lbl = 'Temperature (K)';
        elseif plot_types(ii) == 1 %viscosity
            data = log10(visc) + log10(eta0); %log(Pa s)
            cmap = visc_cmap;
            clims = clim_visc;
            cbar_lbl = 'Viscosity [log(Pa s)]';
        elseif plot_types(ii) == 2
            cmap = velo_cmap;
            if velo_types(ii) == 0 %radial velocity
                data = vel_r ./ scalev; %cm/yr
                cbar_lbl = 'u{_{r}} (cm/yr)';
                clims = clim_vr;
            elseif velo_types(ii) == 1 %latitudinal velocity
                data = vel_colat ./ scalev; %cm/yr
                cbar_lbl = 'u{_{lat}} (cm/yr)';
                clims = clim_vlat;
            elseif velo_types(ii) == 2 %longitudinal velocity
                data = vel_lon ./ scalev; %cm/yr
                cbar_lbl = 'u{_{lon}} (cm/yr)';
                clims = clim_vlon;
            elseif velo_types(ii) == 3 %cross-sectional velocity
                [dataU, dataV, dataW] = convert_vector_sph2cart(rel_radius, colat, lon, vel_r./scalev, vel_colat./scalev, vel_lon./scalev); %cm/yr -- x-, y-, and z-velocity
                data = sqrt(dataU.^2 + dataV.^2 + dataW.^2); %cm/yr
                cbar_lbl = 'u (cm/yr)';
                clims = clim_v;
                cmap = turbo(n_colors);
            end
        elseif plot_types(ii) == 3 %late-stage cumulate concentration
            data_idx = (trc_flav == 0); %create mask for desired data
            xlin_trc = linspace(-1, 1, grid_dim_trc);
            zlin_trc = linspace(-1, 1, grid_dim_trc);
            trc_count = hist3([trc_x, trc_z], 'Ctrs', {xlin_trc, zlin_trc}); %count all tracers in each bin
            [data, data_pos] = hist3([trc_x(data_idx) trc_z(data_idx)], 'Ctrs', {xlin_trc, zlin_trc}); %count only data tracers in each bin
            data = data ./ trc_count; %volume fraction

            cmap = comp_cmap;
            clims = clim_LSC;
            cbar_lbl = 'Late-stage cumulate fraction';
        elseif ismember(plot_types(ii), [4 5]) %Mg-cumulate concentration
            data_idx = (trc_OGr >= Mgsuite_min) & (trc_OGr <= Mgsuite_max); %create mask for desired data
            xlin_trc = linspace(-1, 1, grid_dim_trc);
            zlin_trc = linspace(-1, 1, grid_dim_trc);
            trc_count = hist3([trc_x, trc_z], 'Ctrs', {xlin_trc, zlin_trc}); %count all tracers in each bin
            [data, data_pos] = hist3([trc_x(data_idx) trc_z(data_idx)], 'Ctrs', {xlin_trc, zlin_trc}); %count only data tracers in each bin
            data = data ./ trc_count; %volume fraction
            cmap = mg_cmap;
            clims = clim_mg;
            cbar_lbl = 'Mg-cumulate fraction';

            if plot_types(ii) == 5
                lsc_idx = (trc_flav == 0);
                data_lsc = hist3([trc_x(lsc_idx) trc_z(lsc_idx)], 'Ctrs', {xlin_trc, zlin_trc});
                data_lsc = data_lsc ./ trc_count;
            end
        end

        % Create a meshed approximation of CitcomS data
        if ismember(plot_types(ii), [0 1 2])
            interpolator.Values = data;
            Y = interpolator(X, Z);
        elseif ismember(plot_types(ii), [3 4 5])
            if length(xlin_trc) == grid_dim
                [X, Z] = meshgrid(xlin_trc, zlin_trc);
                in_domain = (sqrt(X.^2 + Z.^2) <= 1) & (sqrt(X.^2 + Z.^2) >= Rcore_nd);
                X(~in_domain) = NaN;
                Z(~in_domain) = NaN;
    
                Y = data;

                if plot_types(ii) == 5
                    Y_lsc = rot90(data_lsc, -1);
                end
            else
                [Xtrc, Ztrc] = meshgrid(xlin_trc, zlin_trc);

                trc_interp = scatteredInterpolant(reshape(Xtrc, [], 1), reshape(Ztrc, [], 1), reshape(data, [], 1));
                interpolator = trc_interp;
    
                Y = interpolator(X, Z);

                if plot_types(ii) == 5
                    interpolator.Values = reshape(data_lsc, [], 1);
                    Y_lsc = rot90(interpolator(X, Z), -1); %for some reason, data need to be rotated by 90°
                end
            end
            Y = rot90(Y, -1);
        end
        
        % Plot data
        if plot_types(ii) == 5
            ax1 = gca;
            axpos = ax1.Position;
            h = pcolor(ax1, X, flipud(Z), Y);
            hold on
            
            ax2 = axes;
            h2 = pcolor(ax2, X, flipud(Z), Y_lsc);
            set(h2, 'LineStyle', 'none')

            ax = [ax1, ax2];

            linkaxes(ax)
            set(ax, 'Position', axpos);
        else
            ax = gca;
            h = pcolor(X, flipud(Z), Y);
        end
        set(h, 'LineStyle', 'none')
        hold on

        % If plotting cross-sectional velocity, plot vectors
        if (plot_types(ii) == 2) && (velo_types(ii) == 3)
            % Make coarse grid to avoid tiny vectors
            ugrid_x = linspace(-1, 1, vector_dim);
            ugrid_y = linspace(-1, 1, vector_dim);
            [ugridX, ugridY] = meshgrid(ugrid_x, ugrid_y);
            in_domain = (sqrt(ugridX.^2 + ugridY.^2) <= 1) & (sqrt(ugridX.^2 + ugridY.^2) >= Rcore_nd);
            ugridX(~in_domain) = NaN;
            ugridY(~in_domain) = NaN;
            
            % Interpolate values on coarse grid
            interpolator.Values = dataU;
            Ux = interpolator(ugridX, ugridY);
            interpolator.Values = dataW;
            Uz = interpolator(ugridX, ugridY);

            % Plot vectors
            quiver(ugridX, flipud(ugridY), Ux, Uz, ...
                'LineWidth', 1, 'Color', [1 0 1])
        end

        % Set color properties
        if plot_types(ii) == 5
            colormap(ax1, mg_cmap)
            clim(ax1, clims)

            alphas = ones([n_colors, 1]);
            alphas(1) = 0;
            colormap(ax2, comp_cmap)
            alphamap(ax2, alphas)
            alpha(h2, 'color')
            clim(ax2, clims)
        else
            colormap(gca, cmap)
            clim(clims)
        end

        % Add colorbar (if enabled by `format_style` setting)
        if format_style == 0
            cbar = colorbar(ax(1), 'southoutside', ...
                'FontSize', fontsize*0.7, 'TickLength', 0.02);
            cbar.Label.String = cbar_lbl;
            cbar.Color = textcolor;
            set(cbar, 'LineWidth', cbar_lw)

            %adjust colorbar position to preserve size, reduce whitespace, and prevent clipping
            cb_adjust = sp_pad * 0.8;
            cbar.Position(2) = Py + (cb_adjust/nrow)*0.8 + (nrow+1)/fig_H;
            cbar.Position(4) = 15 / fig_H;
            set(sbplt, 'Position', [Px Py+cb_adjust sp_W sp_H])

            if plot_types(ii) == 5
                %add second colorbar and adjust colorbar positions
                cbar_w = cbar.Position(3);
                cbar.Position(3) = cbar_w * 0.45;

                cbar2_x = cbar.Position(1) + cbar_w - cbar.Position(3);
                cbar2_y = cbar.Position(2);
                cbar2_w = cbar.Position(3);
                cbar2_h = cbar.Position(4);
                cbar2 = colorbar(ax2, 'southoutside', ...
                    'FontSize', fontsize*0.7, 'TickLength', 0.02, ...
                    'Position', [cbar2_x, cbar2_y, cbar2_w, cbar2_h]);
                cbar2.Label.String = {'Late-stage'; 'cumulate fraction'};
                cbar2.Color = textcolor;
                set(cbar2, 'LineWidth', cbar_lw)
            end
        end

        % Set axis properties
        xlim(ax, [-1, 1])
        ylim(ax, [-1, 1])
        pbaspect(ax(1), [1 1 1])
        if plot_types(ii) == 5
            pbaspect(ax(2), [1 1 1])
            set(ax(2), 'Position', ax1.Position)
        end
        set(ax, 'FontSize', fontsize)
        set(ax, 'visible', 'off')
        set(ax, 'Color', bgcolor)

        % Plot a circle representing the core
        core = patch(core_x, core_y, corecolor);
        set(core, 'LineStyle', 'none')
    
        % Plot a circular border
        plot(circx, circy, 'Color', border_color, 'LineWidth', border_lw)
        plot(circx_core, circy_core, 'Color', border_color, 'LineWidth', border_lw_core)
        
        hold off
    end
    
    % Plot time and progress bar
    txt_H_px = 20; %textbox height in pixels
    txt_H = txt_H_px / fig_H; %convert textbox height to figure-normalized units
    if show_time && ~(format_style == 1)
        time_lbl = sprintf("%3.0f Myr", round(tMyr));
        txt_y = 1 - txt_H;
        annotation('textbox', [0.2 txt_y 0.6 txt_H], ...
            'String', time_lbl, 'FontSize', fontsize, 'Color', textcolor, ...
            'EdgeColor', 'None', 'HorizontalAlignment', 'center')
    end
    if show_progress && ~(format_style == 1)
        bar_W = 0.2 + (0.04 * (3-nrow));
        bar_H_px = 20;
        bar_H = bar_H_px / fig_H;
        bar_x = (1 - bar_W) / 2;
        bar_y = 1 - ( (txt_H_px + 13)/fig_H + bar_H);
        annotation('rectangle', [bar_x bar_y bar_W*time_progress bar_H], 'FaceColor', [.6 .6 .6], 'EdgeColor', 'None')
        annotation('rectangle', [bar_x bar_y bar_W bar_H], 'LineWidth', 2, 'EdgeColor', textcolor)
    end
    
    hold off
    
    % Set figure properties
    fig = gcf;
    set(fig, 'position', [fig_position, 81, fig_W, fig_H])
    set(fig, 'InvertHardCopy', invertsetting)
    set(fig, 'Color', bgcolor)
    set(fig, 'MenuBar', 'none')
    
    % Print a progress message to the console and pause script to update console and figure
    fprintf("%.0f ", tstep)
    if (mod(tID, 20) == 0)
        fprintf("\n")
    end
    pause(0.001)
    
    if save_figs
        saveas(gcf, save_dir+figure_title+tstep+"_"+round(tMyr)+"Myr.png");
    end
end

fprintf("\n")







